import argparse

import torch


def get_args():
    parser = argparse.ArgumentParser(description='RL')
    
    parser.add_argument(
        '--batch_size', type=int, default=50, help='CNFIL batch size (default: 50)')
    parser.add_argument(
        '--num_processes', type=int, default=1)
    parser.add_argument(
        '--experts_dir',
        default='./imitation/imitation_runs/classic/trajs/',
        help='directory that contains expert demonstrations for gail and fgail')
    parser.add_argument(
        '--lr', type=float, default=7e-4, help='learning rate (default: 7e-4)')
    parser.add_argument(
        '--gamma',
        type=float,
        default=0.99,
        help='discount factor for rewards (default: 0.99)')
    parser.add_argument(
        '--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument(
        '--log_interval', type=int, default=10, help='CNFIL logging interval')
    parser.add_argument(
        '--eval_interval', type=int, default=1, help='Evaluation interval')
    parser.add_argument(
        '--env_name',
        default='CartPole-v0', help='environment to train on (default: CartPole-v0)')
    parser.add_argument(
        '--log_dir',
        default='./tmp/gym/', help='directory to save agent logs (default: ./tmp/gym)')
    parser.add_argument(
        '--funcpoint',
        default=None, help='expert state model to use')
    parser.add_argument(
        '--eval_log_dir',
        default='./tmp/gym/eval/', help='directory to save agent logs (default: ./tmp/gym/eval/)')
    parser.add_argument(
        '--save_dir',
        default='./trained_models/',
        help='directory to save agent logs (default: ./trained_models/)')
    parser.add_argument(
        '--recurrent-policy',
        action='store_true', default=False, help='use a recurrent policy')
    parser.add_argument(
        '--adjoint',
        action='store_true', default=False, help='use adjoint or not')
    parser.add_argument(
        '--gpu', type=int, default=0, help='which GPU to use')
    parser.add_argument(
        '--niters', type=int, default=1000, help='iteration time')
    parser.add_argument(
        '--width', type=int, default=64)
    parser.add_argument(
        '--hidden_dim', type=int, default=32)
    parser.add_argument(
        '--max_step_num', type=int, default=1000)
    parser.add_argument("--log_wandb", type=int, default=1)
    parser.add_argument("--num_demo", type=int, default=4, help='number of expert demonstrations to train from')
    parser.add_argument("--sec", type=int, default=3)
    parser.add_argument("--buffer_num", type=int, default=1)
    
    parser.add_argument("--subsample_frequency", type=int, default=4)

    args = parser.parse_args()
    return args
